# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm_ffi
import tvm
from tvm import te
import numpy as np


def test_array():
    """A converted Python list supports len(), negative indexing and slicing."""
    arr = tvm.runtime.convert([1, 2, 3])
    assert len(arr) == 3
    assert arr[-1] == 3
    # Slice [-3:-1] on a 3-element array should keep the first two elements.
    head = arr[-3:-1]
    first, second = head[0], head[1]
    assert (first, second) == (1, 2)


def test_array_save_load_json():
    """A mixed int/float/bool array survives a save_json/load_json round trip.

    Every element is checked, including the first one and the overall length.
    The bool element must come back as a Python ``bool``, not merely as
    something that compares equal to ``True``.
    """
    a = tvm.runtime.convert([1, 2, 3.5, True])
    json_str = tvm.ir.save_json(a)
    a_loaded = tvm.ir.load_json(json_str)
    # Guard against elements being dropped or duplicated during the round trip.
    assert len(a_loaded) == 4
    assert a_loaded[0] == 1
    assert a_loaded[1] == 2
    assert a_loaded[2] == 3.5
    # ``== True`` would also accept the integer 1, so pin the type first and
    # then the exact singleton.
    assert isinstance(a_loaded[3], bool)
    assert a_loaded[3] is True


def test_dir_array():
    """dir() on a converted array returns a non-empty attribute listing."""
    converted = tvm.runtime.convert([1, 2, 3])
    attrs = dir(converted)
    assert attrs


def test_map():
    """A Map keyed by TE vars supports membership, len(), and key/value views."""
    var_a = te.var("a")
    var_b = te.var("b")
    vmap = tvm.runtime.convert({var_a: 2, var_b: 3})
    assert var_a in vmap
    assert len(vmap) == 2

    as_dict = dict(vmap.items())
    assert var_a in as_dict
    assert var_b in as_dict

    # `var_a + 1` is a new expression rather than one of the stored keys.
    assert var_a + 1 not in vmap

    expected_keys = {var_a, var_b}
    assert set(vmap) == expected_keys
    assert set(vmap.keys()) == expected_keys
    assert set(vmap.values()) == {2, 3}


def test_str_map():
    """A str-keyed Map supports membership, len(), and subscripting by str."""
    smap = tvm.runtime.convert({"a": 2, "b": 3})
    assert "a" in smap
    assert len(smap) == 2
    plain = dict(smap.items())
    assert smap["a"] == 2
    for key in ("a", "b"):
        assert key in plain


def test_map_save_load_json():
    """A var-keyed Map survives a save_json/load_json round trip.

    Keys are compared by their ``name`` attribute after reloading.
    """
    var_a = te.var("a")
    var_b = te.var("b")
    source_map = tvm.runtime.convert({var_a: 2, var_b: 3})
    serialized = tvm.ir.save_json(source_map)
    restored = tvm.ir.load_json(serialized)
    assert len(restored) == 2
    by_name = {key.name: value for key, value in restored.items()}
    assert by_name == {"a": 2, "b": 3}


def test_dir_map():
    """dir() on a var-keyed Map returns a non-empty attribute listing."""
    var_a, var_b = te.var("a"), te.var("b")
    converted = tvm.runtime.convert({var_a: 2, var_b: 3})
    listing = dir(converted)
    assert listing


def test_getattr_map():
    """Converting a dict with TE var keys produces a tvm_ffi.Map instance."""
    var_a, var_b = te.var("a"), te.var("b")
    converted = tvm.runtime.convert({var_a: 2, var_b: 3})
    assert isinstance(converted, tvm_ffi.Map)


def test_in_container():
    """Membership on a converted str array accepts both str and StringImm."""
    letters = tvm.runtime.convert(["a", "b", "c"])
    assert "a" in letters
    wrapped = tvm.tir.StringImm("a")
    assert wrapped in letters
    assert "d" not in letters


def test_tensor_container():
    """Tensors stored in a converted array come back as the same objects."""
    tensor = tvm.runtime.tensor([1, 2, 3])
    pair = tvm.runtime.convert([tensor, tensor])
    for element in (pair[0], pair[1]):
        assert element.same_as(tensor)
    assert isinstance(pair[0], tvm.runtime.Tensor)


def test_return_variant_type():
    """testing.ReturnsVariant returns different types for 42 and 17.

    42 yields an IntImm equal to 21. 17 yields the string "argument was odd".
    """
    returns_variant = tvm.get_global_func("testing.ReturnsVariant")

    even_result = returns_variant(42)
    assert isinstance(even_result, tvm.tir.IntImm)
    assert even_result == 21

    odd_result = returns_variant(17)
    assert odd_result == "argument was odd"


def test_pass_variant_type():
    """testing.AcceptsVariant maps a str and an int argument to type-name strings."""
    accepts_variant = tvm.get_global_func("testing.AcceptsVariant")

    # Argument -> string the function is expected to return for it.
    cases = {
        "string arg": "ffi.String",
        17: "ir.IntImm",
    }
    for arg, expected in cases.items():
        assert accepts_variant(arg) == expected


def test_pass_incorrect_variant_type():
    """testing.AcceptsVariant raises when given a FloatImm argument."""
    accepts_variant = tvm.get_global_func("testing.AcceptsVariant")
    bad_arg = tvm.tir.FloatImm("float32", 0.5)

    # The concrete exception class is not pinned here; any raised error counts.
    with pytest.raises(Exception):
        accepts_variant(bad_arg)


if __name__ == "__main__":
    # Nothing above imports the `tvm.testing` submodule explicitly, and a plain
    # `import tvm` is not guaranteed to load it, so import it before use to
    # avoid an AttributeError when the file is run directly.
    import tvm.testing

    tvm.testing.main()
